# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
 Produce the build model
"""

import sys
import mindvision.classification.models.backbones as BACKBONES
import mindvision.classification.models.neck as NECK
import mindvision.classification.models.head as HEAD
from .classifiers.base import BaseClassifier


def _build_component(registry, section, kind):
    """Instantiate one model component (backbone, neck or head) from a config section.

    Args:
        registry: Module whose attributes are the available component classes.
        section: Config section with a ``type`` attribute naming the class in
            ``registry`` and an optional ``params`` mapping of constructor keyword
            arguments. If it is ``None`` or has no ``type``, the component is skipped.
        kind: Component kind ("backbone", "neck", "head") used in the error message.

    Returns:
        The constructed component, or ``None`` when ``section`` names no type.

    Raises:
        AssertionError: If ``registry`` has no attribute named ``section.type``.
    """
    # Only the config lookup is guarded: an AttributeError raised while
    # *constructing* the component must propagate instead of being swallowed.
    try:
        type_name = section.type
    except AttributeError:
        # Section absent (e.g. ``None``) or without ``type``: component is optional.
        return None

    component_cls = getattr(registry, type_name, None)
    if not component_cls:
        # Raised explicitly (not via ``assert``) so the check survives ``python -O``;
        # AssertionError is kept for compatibility with existing callers.
        raise AssertionError("No {} named {}.".format(kind, type_name))

    # A missing or empty ``params`` means "use the constructor defaults".
    params = getattr(section, "params", None)
    return component_cls(**params) if params else component_cls()


def build_model(config):
    """Build a classifier from the backbone/neck/head sections of ``config``.

    Args:
        config: Model config. Its optional ``BACKBONE``, ``NECK`` and ``HEAD``
            sections each carry a ``type`` (a class name looked up in the matching
            ``mindvision.classification.models`` subpackage) and an optional
            ``params`` mapping of constructor keyword arguments.

    Returns:
        BaseClassifier: Network assembled from the built components. A section
        that is absent or has no ``type`` contributes ``None``.

    Raises:
        AssertionError: If a section names a type its subpackage does not define.
    """
    # Components are built in the same order as before: backbone, neck, head.
    backbone = _build_component(BACKBONES, getattr(config, "BACKBONE", None), "backbone")
    neck = _build_component(NECK, getattr(config, "NECK", None), "neck")
    head = _build_component(HEAD, getattr(config, "HEAD", None), "head")

    return BaseClassifier(backbone=backbone, neck=neck, head=head)
